# Author: Kunal Sangani, ksangani@g.harvard.edu

import numpy as np
import pandas as pd
import pickle
from scipy.interpolate import interp1d
from scipy.stats import norm
import matplotlib.pyplot as plt
from scipy.optimize import minimize
import tikzplotlib
import os

# Matplotlib figure sizes (width, height) in inches, named after the
# subplot grid (rows x cols) they are paired with.
FIG_SIZE_1x2 = (9, 4)  # one row, two panels side by side
FIG_SIZE_2x1 = (5, 8)  # two panels stacked vertically
FIG_SIZE_2x2 = (10, 8)  # presumably for a 2x2 grid (not used in this section)

# Workbooks with data from the "Darwinian returns to scale" paper.
#
# Every sheet is read with header=None, and the column selected by
# ``pull_col`` picks one calibration. The original notes count columns
# from 1; Python's pull_col is 0-based:
#   pull_col 0 (col 1): mu_bar = 1.045, delta_bar = delta0  -> efficient selection
#   pull_col 1 (col 2): mu_bar = 1.045, delta_bar = 1.045   -> efficient entry
#   pull_col 2 (col 3): mu_bar = 1.090, delta_bar = delta0  -> efficient selection
#   pull_col 3 (col 4): mu_bar = 1.090, delta_bar = 1.090   -> efficient entry
# mu_bar = 1.045 gives dlogY/dlogL = 0.13 (Bartelme et al. 2019, trade literature).
# mu_bar = 1.090 gives dlogY/dlogL = 0.30 (Jones 2019, growth literature).
# The default, pull_col = 0, is efficient selection with mu_bar = 1.045.
# The CUSTOMMARKUP workbook instead uses mu_bar = 1.201, which comes from an
# oligopoly simulation.
DARWINIAN_REPL_PATH = './data/Prod_Ups_ODE_repl.xlsx'
DARWINIAN_CUSTOMMARKUP_PATH = './data/Prod_Ups_ODE_custommarkup.xlsx'
def read_darwinian_data(pull_col=0, filename=DARWINIAN_REPL_PATH):
	"""Read one column from every sheet of a Darwinian workbook.

	Parameters
	----------
	pull_col : int
		Column label to extract from each sheet. Sheets are read with
		header=None, so this is the 0-based column position.
	filename : str
		Path to the Excel workbook.

	Returns
	-------
	(list of str, list of numpy.ndarray)
		The sheet names in workbook order, and the requested column of each
		sheet in the same order. Sheets longer than the shortest sheet have
		their leading surplus rows dropped, so every array has the same length.
	"""
	# Use the ExcelFile as a context manager so its file handle is released,
	# and reuse it for the read instead of reopening the workbook.
	with pd.ExcelFile(filename) as xls:
		sheets = pd.read_excel(xls, sheet_name=xls.sheet_names, header=None)

	# Some sheets carry extra leading row(s); drop them so that all sheets
	# line up with the shortest one.
	min_len = min(len(frame) for frame in sheets.values())
	columns = [
		frame.iloc[len(frame) - min_len:][pull_col].to_numpy()
		for frame in sheets.values()
	]
	return list(sheets.keys()), columns

def get_key_darwinian_data(pull_col=0, filename=DARWINIAN_REPL_PATH):
	"""Return the five series used downstream from a Darwinian workbook.

	The data are read with read_darwinian_data(pull_col, filename), and the
	sheets named 'lambda', 'mu', 'rho', 'A' and 'y' are picked out.

	Returns
	-------
	tuple of numpy.ndarray
		(shares, markups, passthroughs, productivity, output), taken from the
		'lambda', 'mu', 'rho', 'A' and 'y' sheets in that order. Raises
		ValueError if any of these sheets is missing.
	"""
	sheet_names, columns = read_darwinian_data(pull_col, filename)
	wanted = ('lambda', 'mu', 'rho', 'A', 'y')
	return tuple(columns[sheet_names.index(name)] for name in wanted)

def generate_markups(rho, shares, mu_bar):
	"""Generate markups whose share-weighted harmonic mean equals ``mu_bar``.

	The grid ``theta`` holds len(rho) evenly spaced interior points of (0, 1).
	On it the markup profile solves, via rk4_integrate,

		dmu/dtheta = mu * (mu - 1) * (1 - rho) / rho * dlog(shares)/dtheta,

	with two ingredients:

	- rho is linearly interpolated, with rho = 1 pinned at theta = 0.
	- dlog(shares)/dtheta is a finite-difference slope.

	The starting markup is chosen with scipy.optimize.minimize to make
	1 / E_shares[1/mu] match ``mu_bar``.

	The solution is cached as raw float64 data in
	``data/kimball_store/<mu_bar>.store``; the directory is created if needed.
	The cache is reused on later calls when its length matches len(rho).
	NOTE(review): the cache is keyed on ``mu_bar`` only. A call with a
	different ``rho``/``shares`` of the same length but the same ``mu_bar``
	returns the old result; delete the file to force a recompute.

	Parameters
	----------
	rho : array_like
		Pass-through values on the grid. They appear in a denominator, so they
		must be nonzero.
	shares : array_like
		Positive sales shares on the same grid (len(rho) entries; logs are taken).
	mu_bar : float
		Target harmonic-mean markup. Its str() also names the cache file.

	Returns
	-------
	numpy.ndarray
		Markups, one per grid point.
	"""
	filename = 'data/kimball_store/' + str(mu_bar) + '.store'
	if os.path.exists(filename):
		cached = np.fromfile(filename)
		# Only reuse a cache built on a grid of the same size. Otherwise fall
		# through, recompute and overwrite it.
		if len(cached) == len(rho):
			return cached
	theta = np.linspace(0, 1, len(rho)+2)[1:-1]
	# Pin rho = 1 at theta = 0 so the interpolant covers [0, theta[-1]].
	rho_fn = interp1d(np.append([0], theta), np.append([1], rho))
	# Slope of log shares between adjacent grid points. Each slope is assigned
	# to the right-hand point, and the first slope is extended flat back to
	# theta = 0.
	dloglambda = np.diff(np.log(shares)) / np.diff(theta)
	dloglambda_fn = interp1d(np.append([0], theta), np.append([dloglambda[0], dloglambda[0]], dloglambda))

	def dmu(t, mu):
		rho_t = rho_fn(t)
		return mu*(mu-1)*(1-rho_t)/rho_t * dloglambda_fn(t)

	def mu_from_mu_0(mu_0, target):
		# Objective for minimize: the squared gap between the target and the
		# share-weighted harmonic-mean markup implied by starting value mu_0.
		print('Trying: ', mu_0)
		mu = rk4_integrate(mu_0, theta, dmu)
		mu_bar_pred = 1/(weighted_exp(1/mu, shares))
		print('Predicted mu_bar: ', mu_bar_pred)
		return (target - mu_bar_pred)**2

	# args must be a tuple: (mu_bar) without a trailing comma is just mu_bar.
	res = minimize(mu_from_mu_0, 1.001, args=(mu_bar,), tol=1E-5)
	mu = rk4_integrate(res.x, theta, dmu)
	# Create the cache directory if needed, so that a finished solve is not
	# lost to a FileNotFoundError at save time.
	os.makedirs(os.path.dirname(filename), exist_ok=True)
	mu.tofile(filename)
	return mu

def arr_integration_bounds(arr):
	"""Return right- and left-endpoint Riemann sums for the samples ``arr``.

	``arr`` is treated as samples at len(arr) evenly spaced points spanning a
	unit interval, so each rectangle has width 1 / (len(arr) - 1).

	Returns (right_sum, left_sum). Which of the two is the lower bound depends
	on whether the sampled function is increasing or decreasing.
	"""
	n_intervals = len(arr) - 1
	right_sum = sum(arr[1:])
	left_sum = sum(arr[:-1])
	return right_sum/n_intervals, left_sum/n_intervals

def weighted_exp_bounds(arr, weights):
	"""Bracket the weighted mean E_weights[arr] using Riemann-sum bounds.

	The numerator sum(arr * weights) and the denominator sum(weights) are each
	bounded with arr_integration_bounds. The returned pair is:

	- first numerator bound / second denominator bound, then
	- second numerator bound / first denominator bound.
	"""
	numerator_bounds = arr_integration_bounds(arr * weights)
	denominator_bounds = arr_integration_bounds(weights)
	low = numerator_bounds[0] / denominator_bounds[1]
	high = numerator_bounds[1] / denominator_bounds[0]
	return low, high

def explore_deviation_weighted_exp(arr, weights):
	"""Print how far each Riemann bound lies from the point estimate.

	Prints two fractional (not percent) deviations on one line:
	bound_low / weighted_exp - 1 and bound_high / weighted_exp - 1, where the
	bounds come from weighted_exp_bounds. Returns None.
	"""
	point = weighted_exp(arr, weights)
	lower, upper = weighted_exp_bounds(arr, weights)
	rel_lower = lower/point - 1
	rel_upper = upper/point - 1
	print(rel_lower, rel_upper)

def weighted_exp(arr, weights):
	"""Return the weighted mean E_weights[arr] = sum(arr * weights) / sum(weights)."""
	numerator = sum(arr * weights)
	denominator = sum(weights)
	return numerator / denominator

def weighted_cov(arr1, arr2, weights):
	"""Return Cov_weights(arr1, arr2) = E_w[arr1*arr2] - E_w[arr1] * E_w[arr2]."""
	mean_of_product = weighted_exp(arr1 * arr2, weights)
	mean_1 = weighted_exp(arr1, weights)
	mean_2 = weighted_exp(arr2, weights)
	return mean_of_product - mean_1 * mean_2

#############################################
# RK4 integration and Pareto functions
#############################################

def rk4step(y_last, theta_last, step_size, f):
	"""Advance y' = f(theta, y) by one classical fourth-order Runge-Kutta step.

	Returns the estimate of y at theta_last + step_size.
	"""
	midpoint = theta_last + step_size/2
	k1 = step_size * f(theta_last, y_last)
	k2 = step_size * f(midpoint, y_last + k1/2)
	k3 = step_size * f(midpoint, y_last + k2/2)
	k4 = step_size * f(theta_last + step_size, y_last + k3)
	increment = k1 + 2*k2 + 2*k3 + k4
	return y_last + (1/6)*increment

def rk4_integrate(y_0, theta, f):
	"""Integrate y' = f(theta, y) along the grid ``theta`` using rk4step.

	y_0 is the value at theta[0]. Returns a float array with one entry per
	grid point. Each entry is obtained from the previous one by a single RK4
	step across the local spacing theta[i] - theta[i-1].
	"""
	path = np.ones(len(theta)) * y_0
	consecutive = zip(theta[:-1], theta[1:])
	for i, (t_prev, t_next) in enumerate(consecutive, start=1):
		path[i] = rk4step(path[i - 1], t_prev, t_next - t_prev, f)
	return path

def generate_distributions_from_pareto(pareto_shape, mu_0=1.001, share_0=1E-14, keys=None, darw_dat=None):
	"""Run generate_distributions_from_productivity_dist for Pareto productivity.

	The productivity callable is generate_pareto_dist(pareto_shape). The other
	arguments are forwarded unchanged.
	"""
	dlog_productivity = generate_pareto_dist(pareto_shape)
	return generate_distributions_from_productivity_dist(
		dlog_productivity,
		mu_0=mu_0,
		share_0=share_0,
		keys=keys,
		darw_dat=darw_dat,
	)

def generate_distributions_from_lognormal(sigma, mu_0=1.001, share_0=1E-14, keys=None, darw_dat=None):
	"""Run generate_distributions_from_productivity_dist for log-normal productivity.

	The productivity callable is generate_lognormal_dist(sigma). The other
	arguments are forwarded unchanged.
	"""
	dlog_productivity = generate_lognormal_dist(sigma)
	return generate_distributions_from_productivity_dist(
		dlog_productivity,
		mu_0=mu_0,
		share_0=share_0,
		keys=keys,
		darw_dat=darw_dat,
	)

def generate_pareto_dist(pareto_shape):
	"""Return theta -> 1 / (pareto_shape * (1 - theta)).

	For Pareto productivity with shape k, the quantile function is
	(1 - theta)^(-1/k). Its log-derivative with respect to theta is therefore
	1 / (k * (1 - theta)).
	"""
	def dlog_productivity(theta):
		return 1/(pareto_shape*(1-theta))
	return dlog_productivity

def generate_lognormal_dist(sigma):
	"""Return theta -> sigma / phi(Phi^{-1}(theta)).

	Here phi and Phi are the standard-normal pdf and cdf. If log productivity
	is sigma times a standard normal, its quantile function is
	sigma * Phi^{-1}(theta), and this is that function's derivative in theta.
	"""
	def dlog_productivity(theta):
		z = norm.ppf(theta)
		return sigma/(norm.pdf(z))
	return dlog_productivity

def generate_distributions_from_productivity_dist(productivity_dist, mu_0=1.001, share_0=1E-14, keys=None, darw_dat=None, diagnostics=True):
	"""Build markup and share profiles from a productivity distribution.

	The grid ``theta`` holds len(rho) evenly spaced interior points of (0, 1).
	rho is the 'rho' entry of the Darwinian data, linearly interpolated with
	extrapolation. rk4_integrate is used to solve

		dlog(mu)/dtheta    = (1 - rho) * productivity_dist(theta)
		dlog(share)/dtheta = rho / (mu - 1) * productivity_dist(theta)

	starting from ``mu_0`` and ``share_0`` at the first grid point.

	Parameters
	----------
	productivity_dist : callable
		theta -> d log(productivity) / d theta, for example the result of
		generate_pareto_dist or generate_lognormal_dist.
	mu_0 : float
		Markup at the first grid point. It must not be 1, because mu - 1 is
		used as a denominator.
	share_0 : float
		Share at the first grid point (must be positive; its log is taken).
	keys, darw_dat : list, optional
		Sheet names and columns as returned by read_darwinian_data. If
		darw_dat is None, both are read from the default workbook. Passing
		darw_dat without keys raises ValueError.
	diagnostics : bool
		If True (the default), print the shares and show scatter plots of
		rho, mu and shares against theta. plt.show() may block until the
		windows are closed.

	Returns
	-------
	(shares, mu, rho) : tuple of numpy.ndarray
	"""
	if darw_dat is None:
		keys, darw_dat = read_darwinian_data()
	elif keys is None:
		# keys are required to locate the 'rho' column inside darw_dat.
		raise ValueError('keys must be provided together with darw_dat')
	rho = darw_dat[keys.index('rho')]
	theta = np.linspace(0, 1, num=len(rho)+2)[1:-1]
	rho_fn = interp1d(theta, rho, fill_value='extrapolate')

	# The markup ODE right-hand side does not depend on the current log markup.
	def dlogmu(t, _log_mu):
		return (1-rho_fn(t)) * productivity_dist(t)
	mu = np.exp(rk4_integrate(np.log(mu_0), theta, dlogmu))
	mu_fn = interp1d(theta, mu, fill_value='extrapolate')

	# The share ODE right-hand side uses the markup path just computed.
	def dloglambda(t, _log_share):
		return rho_fn(t)/(mu_fn(t)-1) * productivity_dist(t)
	shares = np.exp(rk4_integrate(np.log(share_0), theta, dloglambda))

	if diagnostics:
		print(shares)
		for series in (rho, mu, shares):
			plt.figure()
			plt.scatter(theta, series)
		plt.show()

	return shares, mu, rho

#############################################
# Save and load object using pickle module
#############################################

def load_object_from_file(filename):
	"""Unpickle and return the object stored in ``filename``.

	Returns None when opening or reading the file raises OSError (for example
	when the file does not exist). Other unpickling errors propagate.

	Security: pickle.load can execute arbitrary code. Only load files this
	project wrote itself, never untrusted input.
	"""
	try:
		with open(filename, 'rb') as stream:
			obj = pickle.load(stream)
	except OSError:  # EnvironmentError is an alias of OSError in Python 3
		return None
	return obj

def save_object_to_file(obj, filename):
	"""Pickle ``obj`` into ``filename``, overwriting any existing file.

	Returns True once the object has been written and the file closed. If the
	file cannot be opened or the object cannot be pickled, the exception
	(OSError, pickle.PicklingError, ...) propagates to the caller, so this
	function never returns False.
	"""
	with open(filename, 'wb') as stream:
		pickle.dump(obj, stream)
	return True

def plot_slopes(calvo_wage_slopes, real_rigid_wage_slopes, phil_wage_slopes, calvo_price_slopes, real_rigid_price_slopes, phil_price_slopes, parameter_space, xlabel, savefile=None, reverse=False, overall_flattening=False):
	"""Plot wage and CPI Phillips-curve slopes side by side against a parameter.

	Each panel shows three slope series against ``parameter_space``:

	- sticky-price channel alone (green), from ``calvo_*``;
	- with real rigidities (orange), from ``real_rigid_*``;
	- with the misallocation channel (blue), from ``phil_*``.

	The left panel uses the wage slopes and the right panel the price (CPI)
	slopes.

	Parameters
	----------
	xlabel : str
		Label for both x-axes.
	savefile : str, optional
		If given, the figure is exported with tikzplotlib to this path.
	reverse : bool
		Flip both x-axes so the parameter decreases left to right. On export,
		tikzplotlib.clean_figure() is skipped in this case.
	overall_flattening : bool
		Not used here; accepted so the signature matches plot_slopes_simple.

	Returns
	-------
	numpy.ndarray of matplotlib Axes
		The two subplot axes (wage, CPI).
	"""
	fig, axs = plt.subplots(1, 2, figsize=FIG_SIZE_1x2)

	panels = (
		(axs[0], 'Wage Phillips curve', calvo_wage_slopes, real_rigid_wage_slopes, phil_wage_slopes),
		(axs[1], 'CPI Phillips curve', calvo_price_slopes, real_rigid_price_slopes, phil_price_slopes),
	)
	for ax, title, calvo, real_rigid, phil in panels:
		ax.plot(parameter_space, calvo, label='Sticky price channel alone', color='green')
		ax.plot(parameter_space, real_rigid, label='Real rigidities included', color='darkorange')
		ax.plot(parameter_space, phil, label='Misallocation channel included', color='blue')
		ax.set_title(title)
		ax.set_xlabel(xlabel)
		if reverse:
			# Swap the x-limits so values decrease from left to right.
			ax.set_xlim(ax.get_xlim()[::-1])

	# Both panels plot the same three series, so only the left one gets a
	# y-label and a legend.
	axs[0].set_ylabel('Phillips curve slope')
	axs[0].legend()

	plt.tight_layout(rect=[0, 0, 1, 0.95])
	if savefile is not None:
		fig.suptitle(None)
		# NOTE(review): clean_figure is skipped for reversed axes, presumably
		# because it mishandles inverted limits; confirm.
		if not reverse:
			tikzplotlib.clean_figure()
		tikzplotlib.save(savefile, extra_axis_parameters=['PlotStyle'])
	return axs

def plot_slopes_simple(calvo_wage_slopes, real_rigid_wage_slopes, phil_wage_slopes, calvo_price_slopes, real_rigid_price_slopes, phil_price_slopes, parameter_space, xlabel, savefile=None, reverse=False, overall_flattening=False):
	"""Plot price Phillips-curve slopes and the flattening attributed to each channel.

	Top panel: price-curve slopes for the standard model, with real
	rigidities, and with the supply-side effect, against ``parameter_space``.

	Bottom panel: flattening ratios
	calvo / real_rigid (real rigidities) and real_rigid / phil (supply-side
	effect). When ``overall_flattening`` is True it also shows calvo / phil.

	The three wage-slope arguments are not used; they are accepted so the
	signature matches plot_slopes.

	Parameters
	----------
	xlabel : str
		Label for the bottom x-axis.
	savefile : str, optional
		If given, the figure is exported with tikzplotlib to this path, using
		the same export settings as plot_slopes.
	reverse : bool
		Flip both x-axes so the parameter decreases left to right. On export,
		tikzplotlib.clean_figure() is skipped in this case.
	overall_flattening : bool
		Also plot the overall flattening ratio calvo / phil.

	Returns
	-------
	numpy.ndarray of matplotlib Axes
		The two subplot axes (top, bottom).
	"""
	real_rigid_contrib_price = calvo_price_slopes / real_rigid_price_slopes
	misalloc_contrib_price = real_rigid_price_slopes / phil_price_slopes

	fig, axs = plt.subplots(2, 1, figsize=FIG_SIZE_2x1)
	axs = axs.flatten()
	top, bottom = axs

	top.plot(parameter_space, calvo_price_slopes, label='Standard model', color='green')
	top.plot(parameter_space, real_rigid_price_slopes, label='Real rigidities included', color='darkorange')
	top.plot(parameter_space, phil_price_slopes, label='Supply-side effect included', color='blue')
	top.set_ylabel('Phillips curve slope')
	top.set_title('Price Phillips curve')
	if reverse:
		top.set_xlim(top.get_xlim()[::-1])
	top.legend(loc='lower left')

	if overall_flattening:
		bottom.plot(parameter_space, calvo_price_slopes / phil_price_slopes, label='Overall flattening', color='black')
	bottom.plot(parameter_space, real_rigid_contrib_price, label='Real rigidities', color='darkorange')
	bottom.plot(parameter_space, misalloc_contrib_price, label='Supply-side effect', color='blue')
	bottom.set_xlabel(xlabel)
	bottom.set_ylabel('Flattening from channel')
	if reverse:
		bottom.set_xlim(bottom.get_xlim()[::-1])
	bottom.legend()

	plt.tight_layout(rect=[0, 0, 1, 0.95])
	# savefile was previously accepted but ignored; export the same way as
	# plot_slopes.
	if savefile is not None:
		fig.suptitle(None)
		if not reverse:
			tikzplotlib.clean_figure()
		tikzplotlib.save(savefile, extra_axis_parameters=['PlotStyle'])
	return axs


